package com.webank.maling.ai.documentation;

import com.webank.maling.ai.config.BatchProcessingConfig;
import com.webank.maling.ai.support.TokenSupport;
import com.webank.maling.base.model.MethodInfo;
import jakarta.annotation.PreDestroy;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;

import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY;

/**
 * 分批文档生成服务
 * 处理大内容的分批生成和合并
 *
 * @author diodehe
 */
@Slf4j
@Service
public class BatchDocumentationService {

    private final PromptBuilder promptBuilder;
    private final ChatClient simpleChatClient;
    private final ChatClient multiConversationChatClient;
    private final ExecutorService executorService;

    @Autowired
    public BatchDocumentationService(
            @Qualifier("simpleChatClient") ChatClient simpleChatClient,
            @Qualifier("multiConversationChatClient") ChatClient multiConversationChatClient,
            PromptBuilder promptBuilder) {
        this.promptBuilder = promptBuilder;

        this.simpleChatClient = simpleChatClient;
        this.multiConversationChatClient = multiConversationChatClient;

        this.executorService = Executors.newFixedThreadPool(BatchProcessingConfig.getMaxConcurrentBatches());
        log.info("分批文档生成服务已初始化，最大并发数: {}，支持多轮对话模式", BatchProcessingConfig.getMaxConcurrentBatches());
    }

    @PreDestroy
    public void shutdown() {
        if (executorService != null && !executorService.isShutdown()) {
            executorService.shutdown();
            log.info("分批文档生成服务已关闭");
        }
    }

    /**
     * 检查是否需要分批处理
     *
     * @param context 文档生成上下文
     * @return 是否需要分批处理
     */
    public boolean needsBatchProcessing(DocumentationGenerationContext context) {
        // 获取当前层级的方法
        List<MethodInfo> methods = context.getMethodsAtLevel();

        // 获取模型配置
        int maxTokens = BatchProcessingConfig.getMaxTokensForLevel(context.getLevel());

        int estimatedTokens = TokenSupport.calculateTokens(methods, ModelConfig.getDefaultConfig().getEncodingType());

        boolean needsBatch = estimatedTokens > maxTokens;
        if (needsBatch) {
            log.info("第{}层内容过大({} tokens > {} tokens)，需要分批处理",
                    context.getLevel(), estimatedTokens, maxTokens);
        }

        return needsBatch;
    }

    /**
     * 分批生成文档（支持所有层级）
     * 根据配置选择使用传统并行模式或多轮对话模式
     *
     * @param context 文档生成上下文
     * @return 合并后的完整文档内容
     */
    public String generateBatchDocumentation(DocumentationGenerationContext context) {
        // 可以通过配置决定使用哪种模式，这里默认使用多轮对话模式
        return generateConversationalDocumentation(context);
//        boolean useConversationalMode = shouldUseConversationalMode(context);
//
//        if (useConversationalMode) {
//            return generateConversationalDocumentation(context);
//        } else {
//            return generateTraditionalBatchDocumentation(context);
//        }
    }


    /**
     * 多轮对话式文档生成
     * 每轮对话都携带上一轮的总结内容和新的代码信息
     *
     * @param context 文档生成上下文
     * @return 完整的文档内容
     */
    public String generateConversationalDocumentation(DocumentationGenerationContext context) {
        try {
            log.info("开始多轮对话式生成第{}层文档，入口点: {}", context.getLevel(), context.getEntryPointId());

            // 生成唯一的对话ID
            String conversationId = generateConversationId(context);

            // 1. 构建分批提示词
            List<String> batchPrompts = promptBuilder.buildBatchPrompts(context);

            if (batchPrompts.isEmpty()) {
                log.error("构建分批提示词失败");
                return null;
            }

            int size = batchPrompts.size();

            log.info("将第{}层文档分为 {} 轮对话处理", context.getLevel(), size);

            // 2. 初始化对话 - 设置整体任务背景
            initializeConversation(conversationId, context, size);

            // 3. 逐轮进行对话，每轮都基于前面的上下文
            StringBuilder finalDocumentation = new StringBuilder();

            for (int i = 0; i < size; i++) {
                log.info("开始第 {} 轮对话 (共 {} 轮)", i + 1, size);

                String roundResponse = callConversationalAI(conversationId, batchPrompts.get(i));

                if (roundResponse != null && !roundResponse.trim().isEmpty()) {
                    log.info("完成第 {} 轮对话，内容长度: {} 字符", i + 1, roundResponse.length());

                    // 如果是最后一轮，请求AI生成最终完整文档
                    if (i == size - 1) {
                        String finalPrompt = buildFinalSummaryPrompt(context);
                        String finalResponse = callConversationalAI(conversationId, finalPrompt);
                        finalDocumentation.append(finalResponse != null ? finalResponse : roundResponse);
                    }
                } else {
                    log.warn("第 {} 轮对话返回空内容", i + 1);
                }
            }

            // 4. 清理对话历史（可选）
            cleanupConversation(conversationId);

            String result = finalDocumentation.toString();
            log.info("完成第{}层多轮对话文档生成，最终内容长度: {} 字符",
                    context.getLevel(), result.length());

            return result;

        } catch (Exception e) {
            log.error("多轮对话生成文档时发生错误", e);
            return null;
        }
    }

    /**
     * 生成对话ID
     */
    private String generateConversationId(DocumentationGenerationContext context) {
        return String.format("doc_gen_%s_level_%d_%d",
                context.getEntryPointId(),
                context.getLevel(),
                System.currentTimeMillis());
    }

    /**
     * 初始化对话，设置整体任务背景
     */
    private void initializeConversation(String conversationId, DocumentationGenerationContext context, int batchCount) {
        String initPrompt = promptBuilder.buildInitializationPromptString(context, batchCount);

        try {
            String response = multiConversationChatClient.prompt()
                    .user(initPrompt)
                    .advisors(a -> a.param(CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId))
                    .call()
                    .content();

            log.debug("对话初始化完成，AI响应: {}", response);
        } catch (Exception e) {
            log.warn("对话初始化失败", e);
        }
    }

    /**
     * 构建最终总结提示词
     */
    private String buildFinalSummaryPrompt(DocumentationGenerationContext context) {
        return promptBuilder.buildFinalIntegrationPromptString(context);
    }

    /**
     * 调用支持记忆的AI服务
     */
    private String callConversationalAI(String conversationId, String prompt) {
        try {
            return multiConversationChatClient.prompt()
                    .user(prompt)
                    .advisors(a -> a.param(CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId))
                    .call()
                    .content();
        } catch (Exception e) {
            log.error("调用对话式AI服务失败", e);
            return null;
        }
    }

    /**
     * 清理对话历史
     */
    private void cleanupConversation(String conversationId) {
        try {
            // 可以选择清理对话历史以释放内存
            // chatMemory.clear(conversationId);
            log.debug("对话历史清理完成: {}", conversationId);
        } catch (Exception e) {
            log.warn("清理对话历史失败: {}", conversationId, e);
        }
    }

    /**
     * 传统的并行分批文档生成（保持向后兼容）
     *
     * @param context 文档生成上下文
     * @return 合并后的完整文档内容
     */
    public String generateTraditionalBatchDocumentation(DocumentationGenerationContext context) {
        try {
            log.info("开始传统分批生成第{}层文档，入口点: {}", context.getLevel(), context.getEntryPointId());

            // 1. 构建分批提示词
            List<String> batchPrompts = promptBuilder.buildBatchPrompts(context);

            if (batchPrompts.isEmpty()) {
                log.error("构建分批提示词失败");
                return null;
            }

            log.info("将第{}层文档分为 {} 批处理", context.getLevel(), batchPrompts.size());

            // 2. 并行生成各批次内容
            List<CompletableFuture<String>> futures = new ArrayList<>();

            for (int i = 0; i < batchPrompts.size(); i++) {
                final int batchIndex = i;
                final String prompt = batchPrompts.get(i);

                CompletableFuture<String> future = CompletableFuture.supplyAsync(() -> {
                    try {
                        log.info("开始生成第 {} 批文档内容", batchIndex + 1);
                        String content = callAIServiceWithRetry(prompt);
                        log.info("完成第 {} 批文档生成，内容长度: {} 字符", batchIndex + 1,
                                content != null ? content.length() : 0);
                        return content;
                    } catch (Exception e) {
                        log.error("生成第 {} 批文档失败", batchIndex + 1, e);
                        return null;
                    }
                }, executorService);

                futures.add(future);
            }

            // 3. 等待所有批次完成并收集结果
            List<String> batchResults = futures.stream()
                    .map(CompletableFuture::join)
                    .collect(Collectors.toList());

            // 4. 合并批次结果
            String mergedContent = mergeBatchResults(batchResults, context);

            log.info("完成第{}层传统分批文档生成，最终内容长度: {} 字符",
                    context.getLevel(), mergedContent != null ? mergedContent.length() : 0);

            return mergedContent;

        } catch (Exception e) {
            log.error("传统分批生成文档时发生错误", e);
            return null;
        }
    }

    /**
     * 带重试机制的AI服务调用
     */
    private String callAIServiceWithRetry(String prompt) {
        Exception lastException = null;
        int maxRetries = 3;
        long retryDelay = 1000;

        for (int attempt = 1; attempt <= maxRetries; attempt++) {
            try {
                log.debug("第 {} 次尝试调用AI服务", attempt);

                String response = simpleChatClient.prompt()
                        .user(prompt)
                        .call()
                        .content();

                if (response != null && !response.trim().isEmpty()) {
                    log.debug("AI服务调用成功，第 {} 次尝试", attempt);
                    return response;
                }

                log.warn("AI服务返回空内容，第 {} 次尝试", attempt);

            } catch (Exception e) {
                lastException = e;
                log.warn("AI服务调用失败，第 {} 次尝试: {}", attempt, e.getMessage());

                if (attempt < maxRetries) {
                    try {
                        Thread.sleep(retryDelay * attempt);
                    } catch (InterruptedException ie) {
                        Thread.currentThread().interrupt();
                        break;
                    }
                }
            }
        }

        log.error("AI服务调用失败，已重试 {} 次", maxRetries, lastException);
        return null;
    }

    /**
     * 合并批次结果
     */
    private String mergeBatchResults(List<String> batchResults, DocumentationGenerationContext context) {
        if (batchResults == null || batchResults.isEmpty()) {
            return null;
        }

        // 过滤掉空结果
        List<String> validResults = batchResults.stream()
                .filter(result -> result != null && !result.trim().isEmpty())
                .toList();

        if (validResults.isEmpty()) {
            log.warn("所有批次结果都为空");
            return null;
        }

        // 简单合并策略：用分隔符连接
        StringBuilder merged = new StringBuilder();
        merged.append("# 第").append(context.getLevel()).append("层级完整说明书\n\n");

        for (int i = 0; i < validResults.size(); i++) {
            if (i > 0) {
                merged.append("\n---\n\n");
            }
            merged.append("## 第").append(i + 1).append("部分\n\n");
            merged.append(validResults.get(i));
        }

        return merged.toString();
    }
}
